# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcls.models.backbones.beit import BEiTTransformerEncoderLayer
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmselfsup.registry import MODELS


@MODELS.register_module()
class BEiTV2Neck(BaseModule):
    """Neck for BEiTV2 Pre-training.

    This module constructs the decoder for the final prediction. It stacks
    ``num_layers`` ``BEiTTransformerEncoderLayer`` blocks (the "patch
    aggregation" layers) followed by a single norm layer that is shared by
    both outputs of :meth:`forward`.

    Args:
        num_layers (int): Number of encoder layers of neck. Defaults to 2.
        early_layers (int): The layer index of the early output from the
            backbone. Defaults to 9.
        backbone_arch (str | dict): Vision Transformer architecture. Either
            one of the names in ``arch_zoo`` (case-insensitive) or a custom
            dict with the keys ``embed_dims``, ``num_heads``,
            ``feedforward_channels`` and ``depth`` (``num_layers`` is
            accepted as an alias of ``depth``). Defaults to 'base'.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initialization value for the
            learnable scaling of attention and FFN. Defaults to 0.1.
        use_rel_pos_bias (bool): Whether to use unique relative position bias,
            if False, use shared relative position bias defined in backbone.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to ``dict(type='TruncNormal', layer='Linear', std=0.02,
            bias=0)``.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'depth': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'depth': 24,
                'num_heads': 16,
                'feedforward_channels': 4096,
            }),
    }

    def __init__(
        self,
        num_layers: int = 2,
        early_layers: int = 9,
        backbone_arch: Union[str, dict] = 'base',
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        layer_scale_init_value: float = 0.1,
        use_rel_pos_bias: bool = False,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(backbone_arch, str):
            backbone_arch = backbone_arch.lower()
            assert backbone_arch in set(self.arch_zoo), \
                (f'Arch {backbone_arch} is not in default archs '
                 f'{set(self.arch_zoo)}')
            self.arch_settings = self.arch_zoo[backbone_arch]
        else:
            # The depth is read below, so a custom arch must provide it.
            # ``num_layers`` is accepted as an alias so that dicts which
            # satisfied the previous key check no longer fail with a
            # ``KeyError`` on ``depth``.
            essential_keys = {
                'embed_dims', 'num_heads', 'feedforward_channels'
            }
            depth_keys = {'depth', 'num_layers'}
            assert isinstance(backbone_arch, dict) \
                and essential_keys <= set(backbone_arch) \
                and depth_keys & set(backbone_arch), \
                (f'Custom arch needs a dict with keys {essential_keys} '
                 f'and one of {depth_keys}')
            self.arch_settings = backbone_arch

        # stochastic depth decay rule
        self.early_layers = early_layers
        if 'depth' in self.arch_settings:
            depth = self.arch_settings['depth']
        else:
            depth = self.arch_settings['num_layers']
        # The schedule spans at least ``early_layers + num_layers`` entries
        # so that every index used in the loop below is valid.
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))

        self.patch_aggregation = nn.ModuleList()
        for i in range(early_layers, early_layers + num_layers):
            _layer_cfg = dict(
                embed_dims=self.arch_settings['embed_dims'],
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                norm_cfg=norm_cfg,
                layer_scale_init_value=layer_scale_init_value,
                window_size=None,
                use_rel_pos_bias=use_rel_pos_bias)
            self.patch_aggregation.append(
                BEiTTransformerEncoderLayer(**_layer_cfg))

        # NOTE(review): this rescales the weights created at construction
        # time. If ``init_weights()`` later re-initializes the Linear layers
        # through ``init_cfg``, the rescaling would be overwritten -- confirm
        # the intended order.
        self.rescale_patch_aggregation_init_weight()

        embed_dims = self.arch_settings['embed_dims']
        _, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

    def rescale_patch_aggregation_init_weight(self):
        """Rescale the initialized weights.

        For every patch aggregation layer, the attention output projection
        weight and the weight of ``ffn.layers[1]`` are divided in place by
        ``sqrt(2 * layer_id)``, where ``layer_id`` is the 1-based index
        ``early_layers + <position in the neck> + 1``.
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.patch_aggregation):
            rescale(layer.attn.proj.weight.data,
                    self.early_layers + layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data,
                    self.early_layers + layer_id + 1)

    def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor],
                rel_pos_bias: torch.Tensor,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            inputs (Tuple[torch.Tensor, torch.Tensor]): Features of tokens.
                ``inputs[0]`` is the early state output and ``inputs[1]`` is
                the final layer output of the backbone. Both are assumed to
                be laid out as (batch, tokens, channels) with the cls_token
                at token index 0 -- TODO confirm against the backbone.
            rel_pos_bias (torch.Tensor): Shared relative position bias table.
                It is forwarded unchanged to every patch aggregation layer.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
              - ``x``: The final layer features from backbone, which are normed
                in ``BEiTV2Neck``, with the first (cls) token removed.
              - ``x_cls_pt``: The early state features from backbone, which
                consist of the final layer cls_token and the early state
                patch_tokens from backbone. They are passed through the
                PatchAggregation layers in the neck and normed, and the
                first (cls) token is removed.
        """

        early_states, x = inputs[0], inputs[1]
        # Pair the final-layer cls_token with the early-layer patch tokens.
        x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
        for layer in self.patch_aggregation:
            x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)

        # shared norm
        x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)

        # remove cls_token
        x = x[:, 1:]
        x_cls_pt = x_cls_pt[:, 1:]
        return x, x_cls_pt
